-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Improve shape handling in generate_samples #3456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… This is now handled in generate_samples.
I added a really small code to use |
I think this is ready for review again |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a couple nits, and will merge after work if no one has any objections!
pymc3/distributions/distribution.py
Outdated
slicer_tail = [slice(None)] * len(sp_shape) | ||
broadcasted_samples.append(param[tuple(slicer_head + slicer_tail)]) | ||
return np.broadcast_arrays(*broadcasted_samples) | ||
return np.asarray(samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
newline
pymc3/distributions/distribution.py
Outdated
size_tup=size_tup, | ||
dist_shape=dist_shape, | ||
broadcast_shape=broadcast_shape, | ||
test=broadcast_shape[:len(size_tup)] == size_tup) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i like informative exceptions and am fine leaving it here, just want to flag that it is named test
, so maybe you intended to remove it?
pymc3/distributions/shape_utils.py
Outdated
@@ -0,0 +1,369 @@ | |||
#!/usr/bin/env python3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think we don't typically include shebangs (#!
) in files
Thanks @lucianopaz! Looks like no one will ever have a shape problem with PyMC3 again 😀 (this is seriously good, careful, and useful work, though!) |
@ColCarroll, one can always dream 🤞 |
This PR fixes #3421 and #3147.
Both errors happened because of how
distributions.distribution.generate_samples
handled the supplied: parameter shapes,broadcast_shape
anddist_shape
. To make this fix more robust, I wrote a small module ofshape
broadcasting utilities that are now used bygenerate_samples
and make it more robust to shape problems. I added a test suite for these functions and also forsample_prior_predictive
with certain RV shapes and sample sizes that caused #3421 and #3147.Issue #3422 remains open because that problem happens in
distributions.distribution._draw_value
and until it is fixed, some of the tests in the testsuite are expected to fail.Things I would like to ask:
generate_samples
function. Could you review the new version to see if you spot problems?